import torch
from torch import nn


class RMSNormal(nn.Module):
    """Root-mean-square layer normalization (RMSNorm).

    Each vector along the last dimension of the input is divided by its root
    mean square and then multiplied by a learnable per-feature gain::

        y = w * x / sqrt(mean(x ** 2, dim=-1) + eps)

    Args:
        input_dim: Size of the last dimension of the input (number of features).
        eps: Small constant added inside the square root so that an all-zero
            feature vector does not cause a division by zero.
    """

    def __init__(self, input_dim, eps=1e-6):
        super().__init__()
        # Gain starts at 1 so the layer is a pure normalization at init;
        # a random (signed) initialization would rescale/flip features arbitrarily.
        self._w = nn.Parameter(torch.ones(input_dim))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` over its last dimension and apply the learned gain.

        Args:
            x: Tensor whose last dimension has size ``input_dim``; every
                position in the leading dimensions is normalized independently.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Reduce over the feature axis only (keepdim so it broadcasts back over x).
        # Reducing over all elements would couple unrelated samples in a batch.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self._w * (x * inv_rms)
